/**
 * Copyright (C) 2015 Baifendian Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.examples.streaming

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.{Seconds, StreamingContext, Time}
import org.apache.spark.util.IntParam
import org.apache.spark.{Accumulator, SparkConf, SparkContext}
import redis.RedisClient

object WordBlacklist {
  @volatile private var instance: Broadcast[Seq[String]] = null

  /**
   * Returns the broadcast word blacklist, creating it on first use.
   *
   * Lazily initialized singleton (double-checked locking on the volatile
   * field) so the broadcast variable can be re-created after a driver
   * failure when the application is recovered from checkpoint.
   */
  def getInstance(sc: SparkContext): Broadcast[Seq[String]] = {
    val cached = instance
    if (cached != null) {
      cached
    } else {
      synchronized {
        if (instance == null) {
          instance = sc.broadcast(Seq("a", "b", "c"))
        }
        instance
      }
    }
  }
}

/**
 * Use this singleton to get or register an Accumulator.
 */
object DroppedWordsCounter {
  @volatile private var instance: Accumulator[Long] = null

  /**
   * Gets or registers the accumulator that counts dropped (blacklisted)
   * words. Lazily initialized singleton so the accumulator survives
   * driver recovery from checkpoint.
   */
  def getInstance(sc: SparkContext): Accumulator[Long] =
    Option(instance).getOrElse {
      synchronized {
        if (instance == null) {
          instance = sc.accumulator(0L, "WordsInBlacklistCounter")
        }
        instance
      }
    }
}

object RedisAddress {
  @volatile private var instance: Broadcast[(String, Int)] = null

  /**
   * Returns the broadcast redis (host, port) pair, creating it on first use.
   *
   * Lazily initialized singleton (double-checked locking) so the broadcast
   * variable can be re-created after driver failure. NOTE: `ip` and `port`
   * are only consulted on the first call; subsequent calls return the
   * already-broadcast address regardless of the arguments.
   */
  def getInstance(sc: SparkContext, ip: String, port: Int): Broadcast[(String, Int)] = {
    val cached = instance
    if (cached != null) {
      cached
    } else {
      synchronized {
        if (instance == null) {
          instance = sc.broadcast(ip -> port)
        }
        instance
      }
    }
  }
}

object RedisConnection {
  var ip: String = "localhost"
  var port: Int = 6379
  var redis: Option[RedisClient] = None

  // One shared ActorSystem for all RedisClient instances. The previous code
  // created a brand-new ActorSystem on EVERY setAddress/getConnection call,
  // leaking its threads each time; an ActorSystem is heavyweight and meant
  // to be created once per application.
  private lazy val akkaSystem: akka.actor.ActorSystem = akka.actor.ActorSystem()

  /**
   * Records the redis address and (re)creates the cached client when the
   * address changes. Also creates the client when none exists yet — the
   * previous version left `redis` as None forever if the first call passed
   * the default "localhost":6379, because the inequality guard short-circuited.
   */
  def setAddress(ip: String, port: Int) = synchronized {
    if (this.ip != ip || this.port != port || redis.isEmpty) {
      this.ip = ip
      this.port = port

      implicit val system = akkaSystem

      redis = Some(RedisClient(ip, port))
    }
  }

  /**
   * Returns the cached client, creating and caching one on first use.
   * (The previous version built a fresh, never-cached client per call
   * whenever `redis` was None.)
   */
  def getConnection: RedisClient = synchronized {
    redis.getOrElse {
      implicit val system = akkaSystem
      val client = RedisClient(ip, port)
      redis = Some(client)
      client
    }
  }
}

/**
 * Shows how to use lazily initialized singleton instances for Broadcast and
 * Accumulator so they can be re-created after driver failure, together with
 * recovering the StreamingContext itself from a checkpoint directory.
 */
object RecoverableNetworkWordCount {

  /**
   * Creates a new StreamingContext and wires up the COMPLETE DStream pipeline.
   *
   * All transformations and output operations must be registered here: when
   * `StreamingContext.getOrCreate` recovers a context from checkpoint this
   * function is skipped, and operations added to the recovered context
   * afterwards are not part of the recovered DAG — the previous version built
   * the pipeline in `main` after `getOrCreate`, so recovery could not work.
   *
   * @param checkpointDirectory HDFS-compatible directory holding checkpoint data
   * @param ip        hostname of the TCP server producing lines of text
   * @param port      port of that TCP server
   * @param redisIp   redis host the word counts are written to
   * @param redisPort redis port
   */
  def createContext(checkpointDirectory: String,
                    ip: String = "localhost",
                    port: Int = 9999,
                    redisIp: String = "localhost",
                    redisPort: Int = 6379)
  : StreamingContext = {
    // If you do not see this printed, the StreamingContext has been loaded
    // from an existing checkpoint instead of newly created.
    println("Creating new context")

    val sparkConf = new SparkConf().setAppName("RecoverableNetworkWordCount")

    // Create the context with a 5 second batch size
    val ssc = new StreamingContext(sparkConf, Seconds(5))
    ssc.checkpoint(checkpointDirectory)

    // Create a socket stream on target ip:port and count the
    // words in input stream of \n delimited text (eg. generated by 'nc')
    val lines = ssc.socketTextStream(ip, port)
    val words = lines.flatMap(_.split(" "))
    val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _)

    wordCounts.foreachRDD((rdd: RDD[(String, Int)], time: Time) => {
      // Word blacklist (broadcast, re-created lazily after driver failure)
      val blacklist = WordBlacklist.getInstance(rdd.sparkContext)

      // Redis address as a broadcast variable
      val address = RedisAddress.getInstance(rdd.sparkContext, redisIp, redisPort)

      // Accumulator counting the words that were dropped
      val droppedWordsCounter = DroppedWordsCounter.getInstance(rdd.sparkContext)

      // Use blacklist to drop words and use droppedWordsCounter to count them
      rdd.filter { case (word, count) =>
        if (blacklist.value.contains(word)) {
          droppedWordsCounter += count
          false
        } else {
          true
        }
      }.foreachPartition { partitionOfRecords =>
        // One connection per partition; RedisConnection caches the client
        RedisConnection.setAddress(address.value._1, address.value._2)
        val connection = RedisConnection.getConnection

        partitionOfRecords.foreach(record => connection.set(record._1, record._2, Some(10)))
      }

      println("Dropped " + droppedWordsCounter.value + " word(s) totally")
    })

    ssc
  }

  def main(args: Array[String]) {
    if (args.length != 5) {
      System.err.println("Your arguments were " + args.mkString("[", ", ", "]"))
      System.err.println(
        """
          |Usage: <hostname> <port> <checkpoint-directory> <redis-hostname> <redis-port>.
          |     <hostname> and <port> describe the TCP server that Spark
          |     Streaming would connect to receive data.
          |     <checkpoint-directory> directory to HDFS-compatible file system which checkpoint data.
          |     <redis-hostname> <redis-port> describe the redis where statistic information written to.
        """.stripMargin
      )

      System.exit(1)
    }

    val Array(ip, IntParam(port), checkpointDirectory, redisIp, IntParam(redisPort)) = args

    // First run: createContext builds the whole pipeline and checkpoints it.
    // After a driver failure: the context — pipeline included — is recovered
    // from the checkpoint directory and createContext is not invoked.
    val ssc = StreamingContext.getOrCreate(checkpointDirectory,
      () => {
        createContext(checkpointDirectory, ip, port, redisIp, redisPort)
      })

    ssc.start()
    ssc.awaitTermination()
  }
}